[tx] General implementation of trainable Hyper Connections #1008

Open
tanmaysachan wants to merge 8 commits into NovaSky-AI:main from tanmaysachan:tanmay/mhc

Conversation

@tanmaysachan
Contributor

@tanmaysachan tanmaysachan commented Feb 2, 2026

Addresses #952

This PR is a general implementation of Hyper Connections.

This is meant to be an extension like LoRA, where the default case mimics a standard residual connection with identity mappings.

Default case: trainable is false and the expansion rate is 1.

  1. H_res is the 1x1 matrix [1]
  2. H_pre and H_post are all-ones vectors, so their matmuls are no-ops

For expansion rate > 1

  1. H_res is initialized as the identity of size n x n (n is the expansion rate)
  2. H_pre is [1/n, 1/n, ...]
  3. H_post is [1, 1, 1, ...]

These matrices preserve the identity mapping, so expansion rate > 1 with trainable disabled still produces the same outputs.
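For concreteness, here is a minimal sketch of that initialization in Flax NNX. The constructor signature matches the one shown in the review below; the method names and einsum layout are illustrative assumptions, not the actual tx/layers/connectors.py code.

import jax.numpy as jnp
from flax import nnx

class Connector(nnx.Module):
    """Hedged sketch of the identity-preserving init described above."""

    def __init__(self, hidden_dim: int, expansion_rate: int, *, trainable: bool = False):
        n = expansion_rate
        # H_res: n x n identity, so each residual stream passes through unchanged.
        self.h_res = nnx.Param(jnp.eye(n))
        # H_pre: [1/n, ..., 1/n], averaging the n streams into the block input
        # (the single value [1] when n == 1).
        self.h_pre = nnx.Param(jnp.full((n,), 1.0 / n))
        # H_post: [1, ..., 1], broadcasting the block output back to all streams.
        self.h_post = nnx.Param(jnp.ones((n,)))
        # hidden_dim and trainable mirror the PR signature; gating gradients
        # on `trainable` is not wired up in this sketch.

    def pre(self, hidden_states):
        # (..., n, C) -> (..., C): mix the n streams into the block input.
        return jnp.einsum("...nc,n->...c", hidden_states, self.h_pre.value)

    def post(self, hidden_states, block_out):
        # Residual mixing across streams plus broadcast-add of the block output.
        residual = jnp.einsum("...nc,nm->...mc", hidden_states, self.h_res.value)
        return residual + jnp.einsum("...c,m->...mc", block_out, self.h_post.value)

With n = 1 this reduces exactly to x + f(x), which is why the untrainable default matches a standard residual connection.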

Todos

  • Simplify RMS integration: added elementwise_affine as a flag (sketched after this list).
  • Benchmark/ensure no regression for expansion_rate = 1: minimal difference in step time when the expansion rate is 1 and untrainable.
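For the RMS item above, a hedged sketch of what an elementwise_affine flag could look like (the flag name comes from the todo; the body is illustrative, not the tx implementation):

import jax
import jax.numpy as jnp
from flax import nnx

class RMSNorm(nnx.Module):
    def __init__(self, size: int, *, eps: float = 1e-6, elementwise_affine: bool = True):
        self.eps = eps
        # Without the affine scale the norm is a pure function, which keeps
        # the expansion_rate = 1 connector path an exact no-op.
        self.weight = nnx.Param(jnp.ones((size,))) if elementwise_affine else None

    def __call__(self, x):
        y = x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + self.eps)
        return y * self.weight.value if self.weight is not None else y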

Future work

  • Fine-tune on custom data with mHC + LoRA to see perf gains

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a general implementation of Hyper Connections as an extension to the transformer layers. The changes are mainly in tx/layers/connectors.py where the Connector module is defined, and in tx/models/deepseekv3.py to integrate it into the decoder layers.

My review found a couple of issues:

  • An unused trainable parameter in the Connector class which should be removed for clarity.
  • A bug in DeepseekV3Model when handling intermediate hidden states for expansion_rate > 1, where squeeze() is used incorrectly.

Overall, the implementation of the Hyper Connections logic seems to follow the intended pattern of pre/post processing around existing attention and MLP blocks. The changes are well-contained. Addressing the mentioned points will improve the robustness and clarity of the implementation.

for layer_idx, layer in enumerate(self.layers):
    if output_hidden_states:
-       all_hidden_states.append(hidden_states)
+       all_hidden_states.append(hidden_states.squeeze())
Contributor

high

hidden_states.squeeze() is used here to process intermediate hidden states. This will only work correctly if expansion_rate is 1. For expansion_rate > 1, squeeze() will have no effect because the expansion dimension has size n > 1. This will result in appending a tensor with an incorrect shape (..., n, C) to all_hidden_states, which is inconsistent with other states and likely to cause issues downstream.

A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.

Suggested change
-    all_hidden_states.append(hidden_states.squeeze())
+    all_hidden_states.append(hidden_states.mean(axis=-2))
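A quick shape check (the shapes here are illustrative, not taken from the PR) makes the difference concrete:

import jax.numpy as jnp

# expansion_rate n = 4: hidden_states is (batch, seq, n, C)
hs = jnp.ones((2, 16, 4, 64))
hs.squeeze().shape      # (2, 16, 4, 64) -- no size-1 dims, so squeeze() is a no-op
hs.mean(axis=-2).shape  # (2, 16, 64)    -- aggregates across the n streams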

    hidden_dim: int,
    expansion_rate: int,
    *,
    trainable: bool = False,
Contributor

medium

The trainable parameter is defined but it is not used anywhere in the Connector class. This could be misleading for developers using this module. Consider removing it from the method signature, and also the assignment self.trainable = trainable on line 27, to improve code clarity.

@pcmoritz pcmoritz added the tx label Feb 2, 2026
self.eps = eps
self.weight = Param(
-    size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs
+    size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs
Contributor Author

Temporary, testing


@pcmoritz
Collaborator

pcmoritz commented Feb 5, 2026

This looks very elegant, thanks a lot for putting it together! Have you tried to do any end-to-end runs yet / studied the performance, both in terms of learning dynamics / accuracy, as well as how much slowdown it incurs? :)

@tanmaysachan
Contributor Author

Just waiting for the weekend to give it a spin 😅

I'll give Qwen0.6B a shot on an A/H100

@pcmoritz
Collaborator

pcmoritz commented Feb 5, 2026

Sounds great! I'm putting together the 0.3.0 release at the moment, so it will probably need to wait until then, but 0.3.1 should come relatively soon thereafter, so it is not a problem. I'll put a callout in the release blog anyway; if somebody wants to try it out, they can just apply the diff themselves given how simple this is :)
